import pandas as pd
from transformers import TrainingArguments, Trainer
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import argparse


# Command-line options for the MRPC fine-tuning run.
parser = argparse.ArgumentParser(
    description='Fine-tune bert-base-cased on GLUE MRPC with a selectable optimizer.')
parser.add_argument('--epochs', default=3, type=int,
                    help='number of training epochs')
parser.add_argument('--batch-size', default=32, type=int,
                    help='total batch size; divided by the number of visible GPUs '
                         'to get the per-device batch size')
parser.add_argument('--lr', default=5e-4, type=float,
                    help='learning rate passed to the chosen optimizer')
parser.add_argument('--optimizer', default='adamW',
                    choices=['sgd', 'adam', 'lamb', 'adamW', 'ALTO', 'adaBelief'],
                    help='optimizer used for training')
parser.add_argument('--beta', default=0.9, type=float,
                    help='first beta of the ALTO optimizer (ignored by the other optimizers)')
args = parser.parse_args()
#################################################################################################################################################
# bert and dataset
#################################################################################################################################################
# Local location of the pretrained BERT checkpoint (shared by tokenizer and model).
BERT_CHECKPOINT = "....../bert-base-cased"

tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
# Classification head with two output classes.
model = BertForSequenceClassification.from_pretrained(BERT_CHECKPOINT, num_labels=2)

# Prefer the GPU whenever CUDA is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Let's use", torch.cuda.device_count(), "GPUs!")
model.to(device)

# Parquet files for each GLUE-MRPC split.
MRPC_SPLIT_FILES = {
    'train': '....../glue-MRPC/train-00000-of-00001.parquet',
    'validation': '....../glue-MRPC/validation-00000-of-00001.parquet',
    'test': '....../glue-MRPC/test-00000-of-00001.parquet',
}
train_df = pd.read_parquet(MRPC_SPLIT_FILES['train'])
valid_df = pd.read_parquet(MRPC_SPLIT_FILES['validation'])
test_df = pd.read_parquet(MRPC_SPLIT_FILES['test'])

def preprocess_dataset(dataset, text_tokenizer=None, max_length=128):
    """Tokenize the ``sentence1``/``sentence2`` pairs of an MRPC DataFrame.

    Args:
        dataset: DataFrame with ``sentence1`` and ``sentence2`` columns.
        text_tokenizer: tokenizer to use; ``None`` (the default) means the
            module-level ``tokenizer``.
        max_length: every pair is truncated / padded to this many tokens.

    Returns:
        dict with keys ``input_ids``, ``token_type_ids`` and ``attention_mask``;
        each value is a list with one encoded entry per row, in row order.
    """
    keys = ('input_ids', 'token_type_ids', 'attention_mask')
    if len(dataset) == 0:
        # Nothing to encode; skip calling the tokenizer with empty batches.
        return {key: [] for key in keys}
    if text_tokenizer is None:
        text_tokenizer = tokenizer
    # One batched pair call instead of a per-row ``iterrows`` loop, which built
    # a Series per row and invoked the tokenizer once per example.
    encoded = text_tokenizer(
        dataset['sentence1'].tolist(),
        dataset['sentence2'].tolist(),
        truncation=True,
        padding='max_length',
        max_length=max_length,
    )
    # Keep only the three fields the original returned, as plain lists.
    return {key: list(encoded[key]) for key in keys}

# Tokenize the train, validation and test splits, in that order.
train_tokenized, valid_tokenized, test_tokenized = (
    preprocess_dataset(split_df) for split_df in (train_df, valid_df, test_df)
)

class MRPCDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing pre-tokenized encodings with labels.

    ``encodings`` maps a field name to a per-example sequence of values;
    ``labels`` holds one label per example. Each item is a dict of tensors,
    one per encoding field, plus a ``labels`` tensor.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings  # field name -> per-example values
        self.labels = labels        # one label per example; also defines len()

    def __getitem__(self, idx):
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

    def __len__(self):
        return len(self.labels)

# Wrap each tokenized split together with its label column.
train_dataset, valid_dataset, test_dataset = (
    MRPCDataset(encodings, frame['label'].tolist())
    for encodings, frame in (
        (train_tokenized, train_df),
        (valid_tokenized, valid_df),
        (test_tokenized, test_df),
    )
)

train_samples = len(train_dataset)
print("Number of training samples in MRPC:", train_samples)

#################################################################################################################################################
# train and eval
#################################################################################################################################################
from torch.optim import Adam, SGD, AdamW
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from optimizers.lamb import create_lamb_optimizer
from optimizers.ALTO import create_ALTO_optimizer
from adabelief_pytorch import AdaBelief

learning_rate = args.lr
# Define the custom optimizer selected by --optimizer. Each entry is a
# zero-argument factory, so only the requested optimizer is constructed.
_optimizer_factories = {
    'sgd': lambda: SGD(model.parameters(), lr=learning_rate),
    'adam': lambda: Adam(model.parameters(), lr=learning_rate),
    'adamW': lambda: AdamW(model.parameters(), lr=learning_rate),
    'adaBelief': lambda: AdaBelief(model.parameters(), lr=learning_rate, betas=(0.9, 0.999)),
    # --beta feeds only the first of ALTO's three betas.
    'ALTO': lambda: create_ALTO_optimizer(model, lr=learning_rate, betas=(args.beta, 0.9, 0.99), weight_decay=1e-4),
    'lamb': lambda: create_lamb_optimizer(model, lr=learning_rate, weight_decay=1e-4),
}
_make_optimizer = _optimizer_factories.get(args.optimizer)
if _make_optimizer is None:
    raise ValueError('Unknown optimizer: {}'.format(args.optimizer))
optimizer = _make_optimizer()

def compute_metrics(pred):
    """Return accuracy and binary precision/recall/F1 for an evaluation pass.

    ``pred`` exposes ``label_ids`` (gold labels) and ``predictions`` (per-class
    scores); the predicted class is the argmax over the last axis.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1_score, _support = precision_recall_fscore_support(
        y_true, y_pred, average='binary')
    return dict(accuracy=acc, f1=f1_score, precision=prec, recall=rec)

# Split the requested total batch across the visible GPUs. device_count() is 0
# on CPU-only machines, where the plain division raised ZeroDivisionError, so
# treat that case as one device. Also never let the per-device size fall to 0
# (an invalid DataLoader batch size) when --batch-size < number of GPUs.
n_devices = max(1, torch.cuda.device_count())
per_device_batch_size = max(1, args.batch_size // n_devices)
print(per_device_batch_size)

training_args = TrainingArguments(
    output_dir="./results",
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy`; confirm against the installed version.
    evaluation_strategy="epoch",
    per_device_train_batch_size=per_device_batch_size,
    per_device_eval_batch_size=per_device_batch_size,
    num_train_epochs=args.epochs,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    compute_metrics=compute_metrics,
    # Use the optimizer built above; None leaves the LR scheduler to Trainer.
    optimizers=(optimizer, None),
)

print("start training...")

trainer.train()

# NOTE(review): assumes the test parquet carries real 0/1 labels; if it uses
# placeholder labels (e.g. -1, as some GLUE test splits do), these metrics are
# not meaningful — confirm against the data.
test_results = trainer.evaluate(eval_dataset=test_dataset)
print("Test Results:", test_results)
